package com.mical.demo.security;

import com.mical.demo.service.TokenBlacklistService;
import io.jsonwebtoken.*;
import io.jsonwebtoken.security.Keys;
import jakarta.annotation.PostConstruct;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.userdetails.User;
import org.springframework.stereotype.Component;

import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.util.Collection;
import java.util.Date;
import java.util.stream.Collectors;

/**
 * Issues, parses and revokes HS256-signed JWTs used for authentication.
 *
 * <p>The signing key is derived from the {@code jwt.secret} property and the token lifetime comes
 * from {@code jwt.expiration} (milliseconds). Revoked tokens are tracked through
 * {@link TokenBlacklistService}.
 */
@Component
public class JwtTokenUtil {
    /** Name of the claim carrying the comma-separated granted authorities. */
    private static final String AUTHORITIES_CLAIM = "auth";

    private final TokenBlacklistService tokenBlacklistService;

    @Value("${jwt.secret}")
    private String secret;

    /** Token lifetime in milliseconds. */
    @Value("${jwt.expiration}")
    private long expiration;

    /** HMAC signing/verification key; initialised in {@link #init()}. */
    private Key key;

    @Autowired
    public JwtTokenUtil(TokenBlacklistService tokenBlacklistService) {
        this.tokenBlacklistService = tokenBlacklistService;
    }

    /**
     * Derives the HMAC key from the configured secret.
     *
     * <p>The secret is encoded as UTF-8 explicitly so the derived key does not depend on the JVM's
     * platform default charset (which, before Java 18, can differ between hosts).
     */
    @PostConstruct
    public void init() {
        this.key = Keys.hmacShaKeyFor(secret.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Creates a signed token for the given authentication.
     *
     * @param authentication the authenticated principal; its name becomes the {@code sub} claim and
     *     its authorities are stored comma-separated in the {@code auth} claim
     * @return the compact, HS256-signed JWT string
     */
    public String generateToken(Authentication authentication) {
        String authoritiesStr = authentication.getAuthorities().stream()
                .map(GrantedAuthority::getAuthority)
                .collect(Collectors.joining(","));

        // Read the clock once so iat and exp are exactly `expiration` ms apart.
        long nowMillis = System.currentTimeMillis();

        return Jwts.builder()
                .setSubject(authentication.getName())
                .claim(AUTHORITIES_CLAIM, authoritiesStr)
                .setIssuedAt(new Date(nowMillis))
                .setExpiration(new Date(nowMillis + expiration))
                .signWith(key, SignatureAlgorithm.HS256)
                .compact();
    }

    /**
     * Checks whether a token is usable: not blacklisted, correctly signed, well-formed and not
     * expired.
     *
     * @param token the compact JWT string
     * @return {@code true} only if every check passes; {@code false} on any JWT parsing/verification
     *     failure (including expiry) or an {@link IllegalArgumentException}
     */
    public boolean validateToken(String token) {
        try {
            if (tokenBlacklistService.isBlacklisted(token)) {
                return false;
            }
            parseClaims(token);
            return true;
        } catch (JwtException | IllegalArgumentException ex) {
            // ExpiredJwtException is a JwtException, so expired tokens land here too, alongside
            // bad signatures, malformed input and null/empty tokens.
            return false;
        }
    }

    /**
     * Builds a Spring Security {@link Authentication} from the token's claims.
     *
     * <p>WARNING: expired tokens are deliberately still accepted here (their claims are used) so
     * that later processing can decide how to handle them, and the blacklist is NOT consulted.
     * Callers that must reject expired or revoked tokens have to call
     * {@link #validateToken(String)} first.
     *
     * @param token the compact JWT string
     * @return an authenticated token whose principal is a {@link User} with an empty password and
     *     whose credentials are the raw token
     * @throws JwtException if the token is malformed or its signature does not verify
     */
    public Authentication getAuthentication(String token) {
        return toAuthentication(parseClaimsAllowingExpired(token), token);
    }

    /**
     * Returns the {@code sub} claim of the token, even if the token has expired.
     *
     * @param token the compact JWT string
     * @return the subject, or {@code null} if the token carries none
     * @throws JwtException if the token is malformed or its signature does not verify
     */
    public String getUsernameFromToken(String token) {
        return parseClaimsAllowingExpired(token).getSubject();
    }

    /**
     * Returns the {@code exp} claim of the token, even if the token has already expired.
     *
     * @param token the compact JWT string
     * @return the expiration date, or {@code null} if the token carries none
     * @throws JwtException if the token is malformed or its signature does not verify
     */
    public Date getExpirationDateFromToken(String token) {
        return parseClaimsAllowingExpired(token).getExpiration();
    }

    /**
     * Reports whether the token is past its expiration date.
     *
     * <p>Fails closed: an unparseable token or one without an {@code exp} claim counts as expired.
     *
     * @param token the compact JWT string
     * @return {@code true} if expired or unreadable
     */
    public boolean isTokenExpired(String token) {
        try {
            Date expiresAt = getExpirationDateFromToken(token);
            return expiresAt == null || expiresAt.before(new Date());
        } catch (Exception ignored) {
            // Any parse/verification failure is treated as "expired" on purpose.
            return true;
        }
    }

    /**
     * Reports whether the token expires within the given threshold (or already has).
     *
     * <p>Fails closed: an unparseable token or one without an {@code exp} claim counts as
     * expiring soon.
     *
     * @param token       the compact JWT string
     * @param thresholdMs look-ahead window in milliseconds
     * @return {@code true} if the remaining lifetime is at most {@code thresholdMs}, or the token
     *     is unreadable
     */
    public boolean isTokenExpiringSoon(String token, long thresholdMs) {
        try {
            Date expiresAt = getExpirationDateFromToken(token);
            return expiresAt == null
                    || expiresAt.getTime() - System.currentTimeMillis() <= thresholdMs;
        } catch (Exception ignored) {
            // Any parse/verification failure is treated as "expiring soon" on purpose.
            return true;
        }
    }

    /**
     * Revokes a token by blacklisting it for the rest of its natural lifetime.
     *
     * <p>Tokens that have already expired are not blacklisted, since {@link #validateToken(String)}
     * rejects them anyway.
     *
     * @param token the compact JWT string
     * @throws JwtException if the token is malformed or its signature does not verify
     * @throws IllegalArgumentException if the token carries no {@code exp} claim, so no blacklist
     *     TTL can be derived
     */
    public void invalidateToken(String token) {
        Date expiresAt = getExpirationDateFromToken(token);
        if (expiresAt == null) {
            throw new IllegalArgumentException(
                    "Cannot invalidate a token without an expiration claim");
        }

        long remainingMillis = expiresAt.getTime() - System.currentTimeMillis();
        if (remainingMillis > 0) {
            tokenBlacklistService.addToBlacklist(token, remainingMillis);
        }
    }

    /** Parses and verifies the token; throws if it is invalid, malformed or expired. */
    private Claims parseClaims(String token) {
        return Jwts.parserBuilder()
                .setSigningKey(key)
                .build()
                .parseClaimsJws(token)
                .getBody();
    }

    /**
     * Parses and verifies the token, but returns the claims of an expired token instead of
     * throwing. Every other failure (bad signature, malformed input) still propagates.
     */
    private Claims parseClaimsAllowingExpired(String token) {
        try {
            return parseClaims(token);
        } catch (ExpiredJwtException ex) {
            // NOTE(review): relies on jjwt verifying the signature before checking exp, so these
            // claims are still authentic — confirm for the jjwt version in use.
            return ex.getClaims();
        }
    }

    /** Converts verified claims into an authenticated Spring Security token. */
    private Authentication toAuthentication(Claims claims, String token) {
        String authoritiesStr = claims.get(AUTHORITIES_CLAIM, String.class);
        Collection<? extends GrantedAuthority> authorities =
                authoritiesStr == null
                        ? AuthorityUtils.NO_AUTHORITIES
                        : AuthorityUtils.commaSeparatedStringToAuthorityList(authoritiesStr);

        User principal = new User(claims.getSubject(), "", authorities);
        // The raw token is kept as the credentials object.
        return new UsernamePasswordAuthenticationToken(principal, token, authorities);
    }
}